Skip to content

Conversation

@skyw
Copy link
Contributor

@skyw skyw commented Oct 13, 2025

Decided to remove split fused parameters logic altogether, because how parameters are fused/stacked together is implementation dependent, it is hard to generalize for everything.
Instead, now it provides interface to plugin more sophisticated orthogonalize function and let user control how to split.

@copy-pr-bot
Copy link

copy-pr-bot bot commented Oct 13, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@skyw skyw force-pushed the skyw/generalize_fused_weight branch from 11e032a to c178725 Compare October 13, 2025 22:05
@skyw skyw requested a review from FDecaYed October 13, 2025 22:05
@skyw
Copy link
Contributor Author

skyw commented Oct 13, 2025

Feels like it may be heading the wrong direction, there will be more cases to support. Alternative is just take an orthoganize function and let users control how to split inside, as well as scale function and all rest of it.

@FDecaYed @mkhona-nvidia let me know what do you think?

@skyw
Copy link
Contributor Author

skyw commented Oct 13, 2025

/ok to test c178725

FDecaYed
FDecaYed previously approved these changes Oct 14, 2025
Copy link

@FDecaYed FDecaYed left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

split_grads_whitened = [self.orthogonalize_fn(g) for g in split_grads]
split_grad_scales = [self.scale_factor_fn(g.size(0), g.size(1)) for g in split_grads]

# TODO(skyw): Revisit whether there are cases that concatenating is not done along dim=0.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nn.conv1d (https://docs.pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) has 3d filter and so the output has to be reshaped to 3d

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid point, that's also one more reason to let user supply orthogonalize function altogether instead of trying to generalize for everything.

Although conv specifically is a completely different case, all rest code assumes 2d, the scale function for example.

@skyw
Copy link
Contributor Author

skyw commented Oct 15, 2025

/ok to test 27aa01c

@mkhona-nvidia mkhona-nvidia marked this pull request as ready for review October 17, 2025 18:12
@skyw
Copy link
Contributor Author

skyw commented Oct 17, 2025

/ok to test bd33fbb

Copy link
Contributor

@mkhona-nvidia mkhona-nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@skyw skyw merged commit d6b8fa6 into main Oct 17, 2025
14 checks passed
@skyw skyw deleted the skyw/generalize_fused_weight branch October 17, 2025 18:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants